{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Question Answering Data Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import sys\n",
    "\n",
    "import pandas\n",
    "import math\n",
    "sys.path.append('../..')\n",
    "from aips import *\n",
    "from IPython.display import HTML,display\n",
    "\n",
    "# Obtain the search engine client (get_engine is presumably provided by the\n",
    "# aips star-import above) and open the \"outdoors\" collection that the\n",
    "# listings below query.\n",
    "engine = get_engine()\n",
    "outdoors_collection = engine.get_collection(\"outdoors\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NOTE: This notebook depends upon the Outdoors dataset. If you have any issues, please rerun the [Setting up the Outdoors Dataset](../ch13/1.setting-up-the-outdoors-dataset.ipynb) notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 14.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_questions():\n",
    "    \"\"\"Fetch outdoors questions whose title opens with a question word.\n",
    "\n",
    "    One search is issued per question word against the title field,\n",
    "    filtered on accepted_answer_id = \"*\" (presumably \"field is present\").\n",
    "    Only hits whose lowercased title starts with that word are kept.\n",
    "\n",
    "    Returns a list of the matching search documents.\n",
    "    \"\"\"\n",
    "    matches = []\n",
    "    for question_word in (\"who\", \"what\", \"when\", \"where\", \"why\", \"how\"):\n",
    "        response = outdoors_collection.search(\n",
    "            query=question_word,\n",
    "            query_fields=[\"title\"],\n",
    "            return_fields=[\"id\", \"url\", \"owner_user_id\",\n",
    "                           \"title\", \"accepted_answer_id\"],\n",
    "            filters=[(\"accepted_answer_id\", \"*\")],\n",
    "            limit=10000)\n",
    "        # Only titles starting with the question word\n",
    "        for document in response[\"docs\"]:\n",
    "            if document[\"title\"].lower().startswith(question_word):\n",
    "                matches.append(document)\n",
    "    return matches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 14.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_answers_from_questions(questions, batch_size=500):\n",
    "    \"\"\"Look up the body of every accepted answer referenced by questions.\n",
    "\n",
    "    The distinct accepted_answer_id values are searched for in batches of\n",
    "    batch_size ids, restricted to documents with post_type \"answer\".\n",
    "\n",
    "    Returns a dict mapping int(answer id) to the answer's body.\n",
    "    \"\"\"\n",
    "    unique_ids = list({str(question[\"accepted_answer_id\"])\n",
    "                       for question in questions})\n",
    "    batch_count = math.ceil(len(unique_ids) / batch_size)\n",
    "    answers = {}\n",
    "    for batch_number in range(batch_count):\n",
    "        first = batch_number * batch_size\n",
    "        batch_ids = unique_ids[first:first + batch_size]\n",
    "        response = outdoors_collection.search(\n",
    "            query=\"(\" + \" \".join(batch_ids) + \")\",\n",
    "            query_fields=\"id\",\n",
    "            limit=batch_size,\n",
    "            filters=[(\"post_type\", \"answer\")],\n",
    "            order_by=[(\"score\", \"desc\")])\n",
    "        # Keys are stored as ints; callers look answers up by id.\n",
    "        for doc in response[\"docs\"]:\n",
    "            answers[int(doc[\"id\"])] = doc[\"body\"]\n",
    "    return answers\n",
    "    \n",
    "def get_context_dataframe(questions):\n",
    "    \"\"\"Pair each question with the body of its accepted answer.\n",
    "\n",
    "    Returns a DataFrame with the columns id, question, context and url,\n",
    "    one row per question. The context is \"Not found\" when the accepted\n",
    "    answer was not returned by get_answers_from_questions.\n",
    "    \"\"\"\n",
    "    answers = get_answers_from_questions(questions)\n",
    "    return pandas.DataFrame({\n",
    "        \"id\": [q[\"id\"] for q in questions],\n",
    "        \"question\": [q[\"title\"] for q in questions],\n",
    "        \"context\": [answers.get(q[\"accepted_answer_id\"], \"Not found\")\n",
    "                    for q in questions],\n",
    "        \"url\": [q[\"url\"] for q in questions]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>question</th>\n",
       "      <th>context</th>\n",
       "      <th>url</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4410</td>\n",
       "      <td>Who places the anchors that rock climbers use?</td>\n",
       "      <td>There are two distinct styles of free rock cli...</td>\n",
       "      <td>https://outdoors.stackexchange.com/questions/4410</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5347</td>\n",
       "      <td>Who places the bolts on rock climbing routes, ...</td>\n",
       "      <td>What you're talking about is Sport climbing. G...</td>\n",
       "      <td>https://outdoors.stackexchange.com/questions/5347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>20662</td>\n",
       "      <td>Who gets the bill if you activate a PLB to hel...</td>\n",
       "      <td>Almost always the victim gets the bill, but as...</td>\n",
       "      <td>https://outdoors.stackexchange.com/questions/2...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>11587</td>\n",
       "      <td>What sort of crane, and what sort of snake?</td>\n",
       "      <td>To answer the snake part of it, looking at som...</td>\n",
       "      <td>https://outdoors.stackexchange.com/questions/1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7623</td>\n",
       "      <td>What knot is this one? What are its purposes?</td>\n",
       "      <td>Slip knot It's undoubtably a slip knot that's ...</td>\n",
       "      <td>https://outdoors.stackexchange.com/questions/7623</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      id                                           question  \\\n",
       "0   4410     Who places the anchors that rock climbers use?   \n",
       "1   5347  Who places the bolts on rock climbing routes, ...   \n",
       "2  20662  Who gets the bill if you activate a PLB to hel...   \n",
       "3  11587        What sort of crane, and what sort of snake?   \n",
       "4   7623      What knot is this one? What are its purposes?   \n",
       "\n",
       "                                             context  \\\n",
       "0  There are two distinct styles of free rock cli...   \n",
       "1  What you're talking about is Sport climbing. G...   \n",
       "2  Almost always the victim gets the bill, but as...   \n",
       "3  To answer the snake part of it, looking at som...   \n",
       "4  Slip knot It's undoubtably a slip knot that's ...   \n",
       "\n",
       "                                                 url  \n",
       "0  https://outdoors.stackexchange.com/questions/4410  \n",
       "1  https://outdoors.stackexchange.com/questions/5347  \n",
       "2  https://outdoors.stackexchange.com/questions/2...  \n",
       "3  https://outdoors.stackexchange.com/questions/1...  \n",
       "4  https://outdoors.stackexchange.com/questions/7623  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Gather the question documents, pair each with its accepted answer's body,\n",
    "# and preview the first five rows.\n",
    "questions = get_questions()\n",
    "contexts = get_context_dataframe(questions)\n",
    "display(contexts[0:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the question/context pairs (without the DataFrame index) as CSV.\n",
    "contexts.to_csv(\"data/outdoors/qa-seed-contexts.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 14.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def get_processor_device(): \n",
    "    return 0 if torch.cuda.is_available() else -1\n",
    "\n",
    "def display_guesses(guesses):\n",
    "    \"\"\"Render the first ten guesses as an HTML table without the index.\"\"\"\n",
    "    preview = pandas.DataFrame(guesses[0:10])\n",
    "    table_html = preview.to_html(index=False)\n",
    "    display(HTML(table_html))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline \n",
    "import tqdm\n",
    "\n",
    "# Model identifier handed to pipeline() as both model and tokenizer below.\n",
    "model_name = \"deepset/roberta-base-squad2\"\n",
    "# Device index for the pipeline: 0 when CUDA is available, otherwise -1.\n",
    "device = get_processor_device()\n",
    "\n",
    "def answer_questions(contexts, k=10):\n",
    "    \"\"\"Run the question-answering pipeline over the first k rows.\n",
    "\n",
    "    contexts: DataFrame with \"question\" and \"context\" columns.\n",
    "    k: maximum number of leading rows to answer.\n",
    "\n",
    "    Returns a list holding one pipeline result per processed row.\n",
    "    \"\"\"\n",
    "    nlp = pipeline(\"question-answering\", model=model_name,\n",
    "                   tokenizer=model_name, device=device)\n",
    "    rows = contexts[0:k]\n",
    "    guesses = []\n",
    "    # Use the sliced length so the progress bar total stays correct\n",
    "    # when k is larger than the number of available rows.\n",
    "    for _, row in tqdm.tqdm(rows.iterrows(), total=len(rows)):\n",
    "        # Keyword arguments replace the dict input form, which the\n",
    "        # pipeline reports as deprecated and slated for removal in v5.\n",
    "        result = nlp(question=row[\"question\"],\n",
    "                     context=row[\"context\"])\n",
    "        guesses.append(result)\n",
    "    return guesses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1663 [00:00<?, ?it/s]/opt/conda/lib/python3.10/site-packages/transformers/pipelines/question_answering.py:391: FutureWarning: Passing a list of SQuAD examples to the pipeline is deprecated and will be removed in v5. Inputs should be passed using the `question` and `context` keyword arguments instead.\n",
      "  warnings.warn(\n",
      "100%|██████████| 1663/1663 [14:27<00:00,  1.92it/s]\n"
     ]
    }
   ],
   "source": [
    "# Answer every row instead of the default first 10; the recorded run below\n",
    "# took roughly 14 minutes for 1663 rows.\n",
    "guesses = answer_questions(contexts, k=len(contexts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>score</th>\n",
       "      <th>start</th>\n",
       "      <th>end</th>\n",
       "      <th>answer</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0.278927</td>\n",
       "      <td>474</td>\n",
       "      <td>516</td>\n",
       "      <td>a local enthusiast or group of enthusiasts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.200848</td>\n",
       "      <td>81</td>\n",
       "      <td>117</td>\n",
       "      <td>the person who is creating the climb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.018632</td>\n",
       "      <td>14</td>\n",
       "      <td>24</td>\n",
       "      <td>the victim</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.000551</td>\n",
       "      <td>1255</td>\n",
       "      <td>1262</td>\n",
       "      <td>aquatic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.222317</td>\n",
       "      <td>29</td>\n",
       "      <td>38</td>\n",
       "      <td>slip knot</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.374997</td>\n",
       "      <td>15</td>\n",
       "      <td>40</td>\n",
       "      <td>a high-tech treasure hunt</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.000087</td>\n",
       "      <td>695</td>\n",
       "      <td>706</td>\n",
       "      <td>water vapor</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.247008</td>\n",
       "      <td>227</td>\n",
       "      <td>265</td>\n",
       "      <td>the traditional longbow made from wood</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.480407</td>\n",
       "      <td>408</td>\n",
       "      <td>473</td>\n",
       "      <td>shoes intended to closely approximate barefoot running conditions</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>0.050992</td>\n",
       "      <td>52</td>\n",
       "      <td>66</td>\n",
       "      <td>Tree of Heaven</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show the first ten pipeline results as an HTML table.\n",
    "display_guesses(guesses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the guesses to contexts in place as an \"answers\" column (requires one\n",
    "# guess per row), then write the combined frame to CSV for manual labeling.\n",
    "contexts[\"answers\"] = guesses\n",
    "contexts.to_csv(\"data/outdoors/qa-squad2-guesses.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 14.7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Manually labeling data**\n",
    "*The above CSV file (data/outdoors/qa-squad2-guesses.csv) is used as a raw first pass at attempting to answer the questions.  This is then used with human-in-the-loop manual correction and labeling of the data.  There is no Python code that can do this for you.  The data MUST be labeled by an intelligent person with an understanding of the domain.  All further listings will use the 'golden set' - the manually corrected answer file, and not the guesses that were generated above.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'url', 'title', 'question', 'context', 'answers', '__index_level_0__'],\n",
       "        num_rows: 1243\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'url', 'title', 'question', 'context', 'answers', '__index_level_0__'],\n",
       "        num_rows: 331\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'url', 'title', 'question', 'context', 'answers', '__index_level_0__'],\n",
       "        num_rows: 84\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import Dataset, DatasetDict\n",
    "# Seeds Python's built-in random module only.\n",
    "# NOTE(review): pandas DataFrame.sample draws from NumPy's RNG unless a\n",
    "# random_state is supplied, so this seed alone does not pin a pandas shuffle.\n",
    "random.seed(0)\n",
    "\n",
    "def get_training_data(filename, random_state=0):\n",
    "    \"\"\"Build train/test/validation datasets from a golden answers CSV.\n",
    "\n",
    "    filename: CSV with at least the columns id, url, question, context,\n",
    "        class and gold, where gold holds \"|\"-separated answer texts.\n",
    "    random_state: seed for the row shuffle, so the splits are reproducible.\n",
    "\n",
    "    Rows without a class value are dropped, and rows where any gold answer\n",
    "    does not occur verbatim in the context are skipped.\n",
    "\n",
    "    Returns a DatasetDict with \"train\", \"test\" and \"validation\" entries.\n",
    "    \"\"\"\n",
    "    golden_answers = pandas.read_csv(filename)\n",
    "    # `!= None` compares elementwise and is True for every row (missing\n",
    "    # values included), so it filters nothing; notna() drops unlabeled rows.\n",
    "    golden_answers = golden_answers[golden_answers[\"class\"].notna()]\n",
    "    qa_data = []\n",
    "    for _, row in golden_answers.iterrows():\n",
    "        answers = row[\"gold\"].split(\"|\")\n",
    "        starts = [row[\"context\"].find(a) for a in answers]\n",
    "        # find() returns -1 for an answer that is absent from the context.\n",
    "        if -1 not in starts:\n",
    "            row[\"title\"] = row[\"question\"]\n",
    "            row[\"answers\"] = {\"text\": answers, \"answer_start\": starts}\n",
    "            qa_data.append(row)\n",
    "    columns = [\"id\", \"url\", \"title\", \"question\", \"context\", \"answers\"]\n",
    "    # sample() uses NumPy's RNG, which random.seed() does not control, so\n",
    "    # the seed is passed explicitly.\n",
    "    df = pandas.DataFrame(qa_data, columns=columns).sample(\n",
    "        frac=1, random_state=random_state)\n",
    "    train_split = int(len(df) * 0.75)\n",
    "    # Clamp so that for very small frames the validation slice can never\n",
    "    # start before the end of the training slice and overlap it.\n",
    "    eval_split = max(train_split,\n",
    "                     int((len(df) - train_split) / 1.25) +\n",
    "                     train_split - 1)\n",
    "    train_dataset = Dataset.from_pandas(df[:train_split])\n",
    "    test_dataset = Dataset.from_pandas(df[train_split:eval_split])\n",
    "    validation_dataset = Dataset.from_pandas(df[eval_split:])\n",
    "    return DatasetDict({\"train\": train_dataset, \"test\": test_dataset,\n",
    "                        \"validation\": validation_dataset})\n",
    "\n",
    "#This golden answers file was labeled by me (Max Irwin).\n",
    "#It took about 2-3 hours to label 200 question/answer rows\n",
    "#Doing so will give you a deeper appreciation for the difficulty of the NLP task.\n",
    "#I *highly* encourage you to label even more documents, and re-run the fine-tuning tasks coming up.\n",
    "datadict = get_training_data(\"data/outdoors/outdoors_golden_answers.csv\")\n",
    "# Target directory for the saved splits; despite its name it is only used as\n",
    "# the save location of the dataset below.\n",
    "model_path = \"data/question-answering/question-answering-training-set\"\n",
    "\n",
    "# NOTE(review): saving is disabled here; presumably a prepared copy already\n",
    "# exists at model_path for later notebooks - confirm before relying on it.\n",
    "#datadict.save_to_disk(model_path)\n",
    "\n",
    "datadict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Up next: [Question Answering LLM Fine-tuning](3.question-answering-fine-tuning.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
